# -*- coding: utf-8 -*-
"""
This script performs rate-based electrode-health prediction and residual
compensation across different C-rates using Dataset 1.
"""

import numpy as np
from numpy import gradient
import torch
from sklearn.model_selection import train_test_split
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
from sklearn import metrics
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as mcolors
import time
from scipy.signal import savgol_filter
from sklearn.model_selection import KFold
from scipy.stats import linregress
from matplotlib.lines import Line2D
from sklearn.metrics import mean_squared_error
from matplotlib.colors import LinearSegmentedColormap
from sklearn.linear_model import MultiTaskLasso
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import joblib
import warnings
warnings.filterwarnings("ignore")

def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU) RNGs and pin cuDNN behaviour.

    Args:
        seed: value passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed`` (in that order).

    cuDNN is switched to deterministic algorithms with benchmark mode off.
    CUDA RNG seeding stays disabled, as before.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

# Default text font for every figure in this script.
font_properties = dict(family='Times New Roman', size=12)
rcParams.update({f'font.{key}': value for key, value in font_properties.items()})


#%%


nominal_capacity = 4.84  # nominal cell capacity [Ah]

# Shared cubic-spline settings: values outside the tabulated SOC range are extrapolated.
_spline_kw = dict(kind='cubic', fill_value='extrapolate', bounds_error=False)

# Half-cell OCP curves measured at C/5.
OCPn_data = pd.read_csv('anode_SiO_Gr_discharge_Cover5_smoothed_dvdq_JS.csv')
OCPp_data = pd.read_csv('cathode_NCA_discharge_Cover5_smoothed_dvdq_JS.csv')
OCPn_SOC, OCPn_V = OCPn_data['SOC_linspace'].values, OCPn_data['Voltage'].values
OCPp_SOC = OCPp_data['SOC_linspace'].values
# The cathode voltage column is reversed (its SOC column is kept as read).
OCPp_V = np.flip(OCPp_data['Voltage'].values).copy()
OCP_p = interp1d(OCPp_SOC, OCPp_V, **_spline_kw)
OCP_n = interp1d(OCPn_SOC, OCPn_V, **_spline_kw)

# Same construction for the C/40 curves.
OCPn_data_40 = pd.read_csv('anode_SiO_Gr_discharge_Cover40_smooth_JS.csv')
OCPp_data_40 = pd.read_csv('cathode_NCA_discharge_Cover40_smooth_JS.csv')
OCPn_SOC_40, OCPn_V_40 = OCPn_data_40['SOC_linspace'].values, OCPn_data_40['Voltage'].values
OCPp_SOC_40 = OCPp_data_40['SOC_linspace'].values
OCPp_V_40 = np.flip(OCPp_data_40['Voltage'].values).copy()
OCP_p_40 = interp1d(OCPp_SOC_40, OCPp_V_40, **_spline_kw)
OCP_n_40 = interp1d(OCPn_SOC_40, OCPn_V_40, **_spline_kw)

#%%
# Publication style: Times New Roman text and math, inward ticks, thin black frame, 12 pt.
# (plt.rcParams and rcParams are the same mapping; keys are applied in this order.)
rcParams.update({
    'font.family': 'Times New Roman',
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Times New Roman',
    'mathtext.it': 'Times New Roman:italic',
    'mathtext.bf': 'Times New Roman:bold',
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'axes.edgecolor': 'black',
    'axes.linewidth': 0.8,
    'font.size': 12,
})

# Fitting configurations to re-plot. Only the active options are listed; the trailing
# comments record the alternatives that were explored.
opt_funcs_options = ['DE']  # 'PSO','DE','GA','CMA-ES','BO'
dofs = [3]  # [2, 3, 4]
object_losses = ['eucl']  # 'eucl','mse','dvf','eucl_mse','eucl_dvf','dvf_eucl','eucl_mse_dvf'
for object_loss in object_losses:

    for dof in dofs:

        for opt_func_trial in opt_funcs_options:
            # One `data_sets` list per file: index 0 = forward fit, index 1 = reverse fit
            # (consumed by the forward-vs-reverse comparison cell).
            data_sets_all = []
            filenames = [f"saved_fittings/resval_extract_data_{opt_func_trial}_DOF{dof}_{object_loss}.npz",
                         f"saved_fittings/resval_extract_data_{opt_func_trial}_DOF{dof}_{object_loss}_reverse.npz"]
            for filename in filenames:

                print(filename)
                norminal_c = 4.84
                # Read every array, then close the archive instead of leaking its file handle.
                with np.load(filename, allow_pickle=True) as data:
                    data_dict = {key: data[key] for key in data.files}
                all_Cq = data_dict['all_Cq']
                all_Cp_opt = data_dict['all_Cp_opt'] * norminal_c
                all_Cn_opt = data_dict['all_Cn_opt'] * norminal_c
                all_x0_opt = data_dict['all_x0_opt']
                all_y0_opt = data_dict['all_y0_opt']
                all_fit_results = data_dict['all_fit_results']

                # all_cells columns: 0 = cell name, 1 = C-rate label.
                all_cells = data_dict['all_cells']
                cell_names = all_cells[:, 0]
                rate_labels = all_cells[:, 1]
                unique_labels = np.unique(rate_labels)

                # Split each per-curve array by C-rate label. Arrays whose leading axis does
                # not match the number of curves (including 0-d scalars, which have no
                # leading axis at all) are shared unchanged by every label.
                split_data_dict = {label: {} for label in unique_labels}
                for key, values in data_dict.items():
                    if key == 'time_consum':
                        continue
                    per_curve = np.ndim(values) > 0 and values.shape[0] == len(rate_labels)
                    for label in unique_labels:
                        split_data_dict[label][key] = values[rate_labels == label] if per_curve else values

                # np.unique sorts the labels; expected ['C/40', 'C/5'].
                if len(unique_labels) != 2:
                    raise ValueError("C/5 and C/40, need to change to suit more C-rates")

                label1, label2 = unique_labels
                subset1, subset2 = split_data_dict[label1], split_data_dict[label2]
                # Display names used in the axis labels.
                label1, label2 = 'C/40', 'C/5'

                # (x @ C/40, y @ C/5, xlabel, ylabel) for Cp, Cn and Qli = Cn*x0 + Cp*y0, in Ah.
                data_sets = [
                    (subset1['all_Cp_opt']*norminal_c, subset2['all_Cp_opt']*norminal_c,
                     r'${\mathrm{C_p}}$@' + label1 + ' [Ah]', r'${\mathrm{C_p}}$@' + label2 + ' [Ah]'),

                    (subset1['all_Cn_opt']*norminal_c, subset2['all_Cn_opt']*norminal_c,
                     r'${\mathrm{C_n}}$@' + label1 + ' [Ah]', r'${\mathrm{C_n}}$@' + label2 + ' [Ah]'),

                    (subset1['all_Cn_opt'] * subset1['all_x0_opt']*norminal_c + subset1['all_Cp_opt'] * subset1['all_y0_opt']*norminal_c,
                     subset2['all_Cn_opt'] * subset2['all_x0_opt']*norminal_c + subset2['all_Cp_opt'] * subset2['all_y0_opt']*norminal_c,
                     r'${\mathrm{Q_{li}}}$@' + label1 + ' [Ah]', r'${\mathrm{Q_{li}}}$@' + label2 + ' [Ah]')
                ]

                data_sets_all.append(data_sets)

                fig, axs = plt.subplots(1, 3, figsize=(22 / 2.54, 7 / 2.54), dpi=600)
                color_map = plt.get_cmap("coolwarm")

                for i, (x, y, xlabel, ylabel) in enumerate(data_sets):
                    ax = axs[i]
                    x_min = min(x)
                    for k in range(len(x)):
                        # Colour by offset from the smallest value (normalised by nominal
                        # capacity, x3 for contrast; values above 1 saturate the colormap).
                        color = color_map(((x[k] / norminal_c) - x_min / norminal_c) * 3)
                        ax.plot(x[k], y[k], marker="o", linestyle="none", markersize=10,
                                color=color, markerfacecolor=color, markerfacecoloralt=color,
                                markeredgecolor="black",
                                # Cycle through the fill styles, excluding the last ('none').
                                fillstyle=Line2D.fillStyles[k % (len(Line2D.fillStyles) - 1)])

                    # Label and style the axes before the early exit below, so a panel with
                    # constant x is still labelled.
                    ax.set_xlabel(xlabel)
                    ax.set_ylabel(ylabel)
                    ax.tick_params(axis='both', which='both', bottom=True, top=False, left=True, right=False)
                    # linregress is undefined when all x values coincide.
                    if np.allclose(x, x[0]):
                        ax.text(0.1, 0.9, "All x identical", transform=ax.transAxes, fontsize=12, color='red')
                        continue
                    slope, intercept, r_value, _, _ = linregress(x, y)
                    reg_x = np.linspace(min(x), max(x), 100)
                    reg_y = slope * reg_x + intercept

                    ax.plot(reg_x, reg_y, 'k-', linewidth=2)
                    ax.text(0.05, 0.9, f"$R$ = {r_value:.3f}", transform=ax.transAxes, fontsize=12, color='black')

                plt.tight_layout()
                plt.show()

#%%
colors = {'C/40': '#B283B9', 'C/5': '#D5CA80'}  # alternative: '#9593C3'
markers = {'C/40': 'o', 'C/5': 's'}
rates = ['C/40', 'C/5']  # same order as the first two entries of each data_sets tuple

# data_sets_all[0] holds the forward fits and data_sets_all[1] the reverse fits. Each is a
# list of (values @ C/40, values @ C/5, xlabel, ylabel) tuples for Cp, Cn and Qli.
forward_data = data_sets_all[0]
reverse_data = data_sets_all[1]

fig, axs = plt.subplots(1, 3, figsize=(22 / 2.54, 6.5 / 2.54), dpi=600)
xlabels = [r'${\mathrm{C_p}}$ [Ah]', r'${\mathrm{C_n}}$ [Ah]', r'${\mathrm{Q_{li}}}$ [Ah]']
ylabels = [r'${\mathrm{C_p}}$ [Ah]', r'${\mathrm{C_n}}$ [Ah]', r'${\mathrm{Q_{li}}}$ [Ah]']
for i, (fwd_set, rev_set) in enumerate(zip(forward_data, reverse_data)):
    ax = axs[i]
    # Reverse-fit value (x) against forward-fit value (y), one colour per C-rate,
    # each followed by its own regression line.
    for rate, fwd_vals, rev_vals in zip(rates, fwd_set[:2], rev_set[:2]):
        ax.scatter(rev_vals, fwd_vals, color=colors[rate], marker=markers[rate], alpha=0.8, linewidth=1, label=rate)
        if np.allclose(rev_vals, rev_vals[0]):
            continue  # linregress is undefined when every x value is identical
        slope, intercept, r_value, _, _ = linregress(rev_vals, fwd_vals)
        reg_x = np.linspace(np.min(rev_vals), np.max(rev_vals), 100)
        reg_y = slope * reg_x + intercept
        ax.plot(reg_x, reg_y, color=colors[rate], linestyle='--', linewidth=1.5, label=f'$R$={r_value:.3f}')

    ax.set_xlabel('Mismatched ' + xlabels[i])
    ax.set_ylabel(ylabels[i])
    ax.legend(
              loc='upper left',
              handletextpad=0.1,
              labelspacing=0.05,
              bbox_to_anchor=(0, 1.05),
              frameon=False,
              fontsize=10)
    # ax.grid(True)
    # ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()


#%%
# Reload the DE / DOF3 / 'eucl' forward fit and split it by C-rate for the
# C/5 -> C/40 mapping model below.
filename = "saved_fittings/resval_extract_data_DE_DOF3_eucl.npz"
print(filename)
norminal_c = 4.84
# Read every array, then close the archive instead of leaking its file handle.
with np.load(filename, allow_pickle=True) as data:
    data_dict = {key: data[key] for key in data.files}
all_Cq = data_dict['all_Cq']*norminal_c
all_Cp_opt = data_dict['all_Cp_opt']*norminal_c
all_Cn_opt = data_dict['all_Cn_opt']*norminal_c
all_x0_opt = data_dict['all_x0_opt']
all_y0_opt = data_dict['all_y0_opt']

# all_cells columns: 0 = cell name, 1 = C-rate label.
all_cells = data_dict['all_cells']
cell_names = all_cells[:, 0]
rate_labels = all_cells[:, 1]
unique_labels = np.unique(rate_labels)

# Split each per-curve array by C-rate label. Arrays whose leading axis does not match
# the number of curves (including 0-d scalars) are shared unchanged by every label.
split_data_dict = {label: {} for label in unique_labels}
for key, values in data_dict.items():
    if key == 'time_consum':
        continue
    per_curve = np.ndim(values) > 0 and values.shape[0] == len(rate_labels)
    for label in unique_labels:
        split_data_dict[label][key] = values[rate_labels == label] if per_curve else values

# np.unique sorts the labels; expected ['C/40', 'C/5'].
if len(unique_labels) != 2:
    raise ValueError("C/5 and C/40, need to change to suit more C-rates")

label1, label2 = unique_labels
subset1, subset2 = split_data_dict[label1], split_data_dict[label2]

# (values @ label1, values @ label2, name1, name2) for Cq, Cp, Cn and Cli = Cn*x0 + Cp*y0, in Ah.
data_sets = [
    (subset1['all_Cq']*norminal_c, subset2['all_Cq']*norminal_c, f'C {label1}', f'C {label2}'),
    (subset1['all_Cp_opt']*norminal_c, subset2['all_Cp_opt']*norminal_c, f'Cp {label1}', f'Cp {label2}'),
    (subset1['all_Cn_opt']*norminal_c, subset2['all_Cn_opt']*norminal_c, f'Cn {label1}', f'Cn {label2}'),
    (subset1['all_Cn_opt'] * subset1['all_x0_opt']*norminal_c + subset1['all_Cp_opt'] * subset1['all_y0_opt']*norminal_c,
     subset2['all_Cn_opt'] * subset2['all_x0_opt']*norminal_c + subset2['all_Cp_opt'] * subset2['all_y0_opt']*norminal_c,
     f'Cli {label1}', f'Cli {label2}')
]



# C/5 -> C/40 mapping evaluated over 50 random 60/20/20 train/val/test splits.
# The split seeds are drawn from a fixed master seed for reproducibility.
np.random.seed(123)
seeds = np.random.choice(9999, size=50, replace=False).tolist()

# Inputs: features fitted at C/5. Targets: the same features fitted at C/40.
X = np.column_stack([
    data_sets[0][1],  # Cq @ C/5
    data_sets[1][1],  # Cp @ C/5
    data_sets[2][1],  # Cn @ C/5
    data_sets[3][1],  # Cli @ C/5
])
Y = np.column_stack([
    data_sets[0][0],  # Cq @ C/40
    data_sets[1][0],  # Cp @ C/40
    data_sets[2][0],  # Cn @ C/40
    data_sets[3][0],  # Cli @ C/40
])

# mask = (Y[:, 0] / 4.84) >= 0.7
# X = X[mask]
# Y = Y[mask]
colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']

# output_names = ["Cq", "Cp", "Cn", "Qli"]
output_names =[r'${\mathrm{C_q}}$',r'${\mathrm{C_p}}$',r'${\mathrm{C_n}}$',r'${\mathrm{Q_{li}}}$']
alphas = np.logspace(-2, 2, 50)

all_mae_sums = []
all_mae_per_output = []
all_rmse_per_output = []
all_model_records = []

for seed in seeds:
    X_train_full, X_test, Y_train_full, Y_test = train_test_split(X, Y, test_size=0.2, random_state=seed)
    X_train, X_val, Y_train, Y_val = train_test_split(X_train_full, Y_train_full, test_size=0.25, random_state=seed)

    # Alpha selection: scalers are fitted on the training split only.
    scaler_X = StandardScaler().fit(X_train)
    scaler_Y = StandardScaler().fit(Y_train)
    X_train_scaled = scaler_X.transform(X_train)
    X_val_scaled = scaler_X.transform(X_val)
    Y_train_scaled = scaler_Y.transform(Y_train)

    best_alpha, best_score = None, float('inf')
    for alpha in alphas:
        model = MultiTaskLasso(alpha=alpha)
        model.fit(X_train_scaled, Y_train_scaled)
        Y_val_pred = scaler_Y.inverse_transform(model.predict(X_val_scaled))
        rmse = np.sqrt(((Y_val_pred - Y_val) ** 2).mean())
        if rmse < best_score:
            best_alpha = alpha
            best_score = rmse

    # Refit with the chosen alpha on train + validation, with scalers refitted on that set.
    X_final = np.vstack([X_train, X_val])
    Y_final = np.vstack([Y_train, Y_val])
    scaler_X = StandardScaler().fit(X_final)
    scaler_Y = StandardScaler().fit(Y_final)
    X_final_scaled = scaler_X.transform(X_final)
    Y_final_scaled = scaler_Y.transform(Y_final)

    final_model = MultiTaskLasso(alpha=best_alpha)
    final_model.fit(X_final_scaled, Y_final_scaled)

    Y_test_pred = scaler_Y.inverse_transform(final_model.predict(scaler_X.transform(X_test)))
    # Use the raw test targets as ground truth. Scaling them with the train-only scaler
    # and un-scaling with the refitted scaler would not give back the original values.
    Y_test_true = Y_test

    # Per-output metrics on capacities normalised by the nominal capacity.
    metrics_per_output = []
    for i in range(Y.shape[1]):
        y_t = Y_test_true[:, i] / norminal_c
        y_p = Y_test_pred[:, i] / norminal_c
        # sqrt(MSE) instead of squared=False, which newer scikit-learn no longer accepts.
        rmse = np.sqrt(mean_squared_error(y_t, y_p))
        mae = mean_absolute_error(y_t, y_p)
        r2 = r2_score(y_t, y_p)
        metrics_per_output.append((rmse, mae, r2))
    mae_sum = sum(m[1] for m in metrics_per_output)

    all_mae_sums.append(mae_sum)
    all_mae_per_output.append([m[1] for m in metrics_per_output])
    all_rmse_per_output.append([m[0] for m in metrics_per_output])
    all_model_records.append({
        "model": final_model,
        "scaler_X": scaler_X,
        "scaler_Y": scaler_Y,
        # This run's own splits, so the plots below show the data this model saw.
        "X_train": X_train,
        "Y_train": Y_train,
        "X_val": X_val,
        "Y_val": Y_val,
        "X_test": X_test,
        "Y_true": Y_test_true,
        "Y_pred": Y_test_pred,
        "metrics": metrics_per_output,
        "mae_sum": mae_sum
    })

# Best / median / worst runs ranked by summed per-output test MAE.
idx_sorted = np.argsort(all_mae_sums)
idx_best = idx_sorted[0]
idx_median = idx_sorted[len(idx_sorted)//2]
idx_worst = idx_sorted[-1]
index_dict = {'Best': idx_best, 'Median': idx_median, 'Worst': idx_worst}

custom_cmap = LinearSegmentedColormap.from_list("colors", ['#7BABD2', '#B3C786'])
for label, idx in index_dict.items():
    record = all_model_records[idx]
    model = record["model"]
    scaler_X = record["scaler_X"]
    scaler_Y = record["scaler_Y"]
    # Named so it does not shadow the imported sklearn `metrics` module.
    record_metrics = record["metrics"]
    print(f"Results for {label}:")
    # Predictions of this run's model on its own train/validation splits.
    Y_train_pred_all = scaler_Y.inverse_transform(model.predict(scaler_X.transform(record["X_train"])))
    Y_val_pred_all = scaler_Y.inverse_transform(model.predict(scaler_X.transform(record["X_val"])))
    for i, name in enumerate(output_names):
        fig, axs = plt.subplots(1, 1, figsize=(7.5/ 2.54, 6 / 2.54), dpi=600)
        Y_train_true = record["Y_train"][:, i]
        Y_val_true   = record["Y_val"][:, i]
        Y_test_true  = record["Y_true"][:, i]

        Y_train_pred = Y_train_pred_all[:, i]
        Y_val_pred   = Y_val_pred_all[:, i]
        Y_test_pred  = record["Y_pred"][:, i]

        plt.scatter(Y_train_true, Y_train_pred, c=colors[0], marker='o', label='Train', alpha=0.6)
        plt.scatter(Y_val_true,   Y_val_pred,   c=colors[4], marker='^', label='Validation', alpha=0.6)
        plt.scatter(Y_test_true,  Y_test_pred,  c=colors[1], marker='s', label='Test', alpha=0.9)

        all_true = np.concatenate([Y_train_true, Y_val_true, Y_test_true])
        plt.plot([all_true.min(), all_true.max()],
                 [all_true.min(), all_true.max()], '--', c=colors[3], label='Ideal')

        rmse, mae, r2 = record_metrics[i]
        print(f"{name:>4s} → RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")
        plt.xlabel(f"Fitted {name} @ C/40")
        plt.ylabel(f"Predicted {name} @ C/40")
        plt.legend(loc='upper left',
                   handletextpad=0.1,
                   labelspacing=0.05,
                   frameon=False,
                   fontsize=10)
        plt.tight_layout()
        plt.show()

    # Coefficient matrix: rows = C/40 outputs, columns = C/5 inputs.
    fig, axs = plt.subplots(1, 1, figsize=(7/ 2.54, 6 / 2.54), dpi=600)
    sns.heatmap(model.coef_, annot=True, alpha=0.7,
                cmap=custom_cmap,
                annot_kws={"color": "black","size": 9},
                cbar_kws={"shrink": 0.9},
                xticklabels=output_names,
                yticklabels=output_names)
    plt.xlabel("Inputs @ C/5")
    plt.ylabel("Outputs @ C/40")
    plt.tight_layout()
    plt.show()

# Summary over the repeated splits: per-output distribution of the test MAE and RMSE
# (computed in the loop above on capacities divided by 4.84).
mae_array = np.array(all_mae_per_output)  # shape: [n_seeds, 4]
rmse_array = np.array(all_rmse_per_output)  # shape: [n_seeds, 4]
rmse_df = pd.DataFrame(rmse_array, columns=output_names)
mean_rmses = rmse_df.mean()  # per-output mean RMSE (not used by the plot below)
markers = ['o', '^']  # index 0 = MAE, 1 = RMSE
metric_labels = ['MAE', 'RMSE']

fig, ax = plt.subplots(figsize=(9 / 2.54, 6 / 2.54), dpi=600)

x_positions = np.arange(len(output_names)) + 1
offset = 0.2
mean_points = {0: [], 1: []}  # metric index -> [(x, mean), ...] for the dashed mean lines


for i, name in enumerate(output_names):
    for metric_idx, data_array in enumerate([mae_array, rmse_array]):
        # MAE (metric_idx 0) sits at +offset and RMSE (metric_idx 1) at -offset from the tick.
        x = x_positions[i] + (-1)**metric_idx * offset
        data = data_array[:, i]
        mean_val = np.mean(data)
        mean_points[metric_idx].append((x, mean_val))
        vp = ax.violinplot([data], positions=[x], widths=0.3,
                           showmeans=False, showmedians=True, showextrema=False)
        for pc in vp['bodies']:
            pc.set_facecolor(colors[metric_idx])
            pc.set_edgecolor(colors[metric_idx])
            pc.set_alpha(0.4)
        vp['cmedians'].set_color(colors[metric_idx])
        vp['cmedians'].set_linewidth(1)
        # Individual runs as horizontally jittered points (draws from NumPy's global RNG).
        x_jittered = np.random.normal(x, 0.03, size=len(data))
        ax.scatter(x_jittered, data,
                   color=colors[metric_idx],
                   marker=markers[metric_idx],
                   alpha=0.7,
                   s=10,
                   edgecolors='white',
                   linewidths=0.0)

        # Mean marker drawn on top of the violin.
        ax.scatter(x, mean_val,
                   color='grey',
                   edgecolor=colors[metric_idx],
                   marker='s',
                   s=15,
                   zorder=4,
                   linewidth=1)

# Dashed line joining the per-output means of each metric.
for metric_idx in [0, 1]:
    xs = [pt[0] for pt in mean_points[metric_idx]]
    ys = [pt[1] for pt in mean_points[metric_idx]]
    ax.plot(xs, ys, '--', color='gray', linewidth=1, alpha=0.8)


ax.set_xticks(x_positions)
ax.set_xticklabels(output_names)
ax.set_ylabel("Error Value")
# ax.set_title("Test Error Distribution across 50 Runs (MAE vs RMSE)", fontsize=9)

# Proxy legend handles: the violins and scatters above are not labelled.
legend_elements = [
    Line2D([0], [0], marker='o', color=colors[0], label='MAE', markersize=3, alpha=0.7,linestyle='None'),
    Line2D([0], [0], marker='^', color=colors[1], label='RMSE', markersize=3, alpha=0.7,linestyle='None'),
    Line2D([0], [0], marker='s', color='gray', label='Mean', linestyle='--', markersize=3)
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=10, handletextpad=0.1, labelspacing=0.05,frameon=False)

plt.tight_layout()
# Fixed y-range: runs with errors above 0.065 would fall outside the plot.
plt.ylim([0,0.065])
plt.show()



# Final C/5 -> C/40 electrode model: pick alpha by K-fold CV, then refit on all data.
np.random.seed(123)
n_folds = 5
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
X = np.column_stack([
    data_sets[0][1],  # Cq @ C/5
    data_sets[1][1],  # Cp @ C/5
    data_sets[2][1],  # Cn @ C/5
    data_sets[3][1],  # Cli @ C/5
])
Y = np.column_stack([
    data_sets[0][0],  # Cq @ C/40
    data_sets[1][0],  # Cp @ C/40
    data_sets[2][0],  # Cn @ C/40
    data_sets[3][0],  # Cli @ C/40
])

colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']

# output_names = ["Cq", "Cp", "Cn", "Qli"]
output_names =[r'${\mathrm{C_q}}$',r'${\mathrm{C_p}}$',r'${\mathrm{C_n}}$',r'${\mathrm{Q_{li}}}$']
alphas = np.logspace(-2, 2, 50)

best_alphas = []

for train_index, val_index in kf.split(X):
    X_train, X_val = X[train_index], X[val_index]
    Y_train, Y_val = Y[train_index], Y[val_index]

    # Scalers are fitted on the training fold only.
    scaler_X = StandardScaler().fit(X_train)
    scaler_Y = StandardScaler().fit(Y_train)
    X_train_scaled = scaler_X.transform(X_train)
    X_val_scaled = scaler_X.transform(X_val)
    Y_train_scaled = scaler_Y.transform(Y_train)

    # Grid search: keep the alpha with the lowest validation RMSE in original units.
    best_alpha, best_score = None, float('inf')
    for alpha in alphas:
        model = MultiTaskLasso(alpha=alpha)
        model.fit(X_train_scaled, Y_train_scaled)
        Y_val_pred = scaler_Y.inverse_transform(model.predict(X_val_scaled))
        rmse = np.sqrt(((Y_val_pred - Y_val) ** 2).mean())
        if rmse < best_score:
            best_alpha = alpha
            best_score = rmse

    best_alphas.append(best_alpha)
    print(f"Fold best alpha: {best_alpha:.4f}, RMSE: {best_score:.4f}")

# NOTE(review): this is the arithmetic mean of values taken from a log-spaced grid. A geometric
# mean (np.exp(np.mean(np.log(best_alphas)))) may be more representative; kept as-is.
final_alpha = np.mean(best_alphas)
# final_alpha =0.01
print(f"\nFinal averaged alpha from {n_folds} folds: {final_alpha:.4f}")

# Refit on the full data set; these scalers/model are reused by the reconstruction cell.
scaler_X_full = StandardScaler().fit(X)
scaler_Y_full = StandardScaler().fit(Y)
X_scaled = scaler_X_full.transform(X)
Y_scaled = scaler_Y_full.transform(Y)

final_model = MultiTaskLasso(alpha=final_alpha)
final_model.fit(X_scaled, Y_scaled)

print("Final model coefficients:")
print(final_model.coef_)
joblib.dump(final_model, 'saved_fittings/'+'electrode_C5_to_C40.pkl')

#%%

# Predict C/40 electrode parameters from the C/5 fits and rebuild C/40 OCV curves.
opt_func_trial = 'DE'
dof = 3
norminal_c = 4.84
# object_losses = ['mse', 'eucl', 'dvf', 'eucl_dvf']
colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']
object_losses = ['eucl']
# Take the loss from this cell's own setting. Relying on the `object_loss` loop variable
# left over from the first cell fails when this cell is run on its own.
object_loss = object_losses[0]
filename = f"saved_fittings/resval_extract_data_{opt_func_trial}_DOF{dof}_{object_loss}.npz"
# Left open on purpose: the next cell reads further arrays from `data`.
data = np.load(filename, allow_pickle=True)
all_Cp_opt = data["all_Cp_opt"]
all_Cn_opt = data["all_Cn_opt"]
all_x0_opt = data["all_x0_opt"]
all_y0_opt = data["all_y0_opt"]
all_Cq = data["all_Cq"]
all_OCV_fit = data["all_OCV_fit"]
all_cell_cap = data["all_cell_cap"]
all_cell_ocv = data["all_cell_ocv"]
all_cell_vmea = data["all_cell_vmea"]
all_cells = data["all_cells"]

# Cli = Cp*y0 + Cn*x0 per fit, in the same (nominal-capacity-normalised) units as Cp/Cn.
all_Cli = np.array([
    np.array(cp) * np.array(y0) + np.array(x0) * np.array(cn)
    for cp, y0, x0, cn in zip(all_Cp_opt, all_y0_opt, all_x0_opt, all_Cn_opt)
])

# Rows [0, n_c40) are the C/40 fits; rows [n_c40, end) are the C/5 fits.
n_c40 = 94

# Map every C/5 fit to predicted C/40 features [Cq, Cp, Cn, Cli] in Ah.
all_predictions = []

for i in range(n_c40, len(all_Cp_opt)):
    X_new = np.column_stack([
        np.array(all_Cq[i]) * norminal_c,
        np.array(all_Cp_opt[i]) * norminal_c,
        np.array(all_Cn_opt[i]) * norminal_c,
        np.array(all_Cli[i]) * norminal_c,
    ])
    X_new_scaled = scaler_X_full.transform(X_new)
    Y_new_scaled_pred = final_model.predict(X_new_scaled)
    Y_new_pred = scaler_Y_full.inverse_transform(Y_new_scaled_pred)
    all_predictions.append(Y_new_pred)  # shape: (rows, 4); row 0 is used below

# Rebuild one OCV curve per cell from its predicted C/40 parameters and the C/40 OCP
# curves, on the same number of points as the stored curves.
all_cell_ocv_construct = np.zeros_like(all_cell_ocv[0:n_c40], dtype=float)
n_points = all_cell_ocv_construct.shape[1]
for idx in range(len(all_cell_ocv_construct)):
    # Predictions are in Ah; convert back to units of nominal capacity.
    Cq, Cp, Cn, Cli = all_predictions[idx][0] / norminal_c
    y0 = 0
    # With y0 = 0, Cli = Cn*x0 + Cp*y0 gives x0 = Cli / Cn.
    x0 = Cli / Cn

    # measured_Q: n_points capacities evenly spaced from 0 to Cq.
    measured_Q = np.linspace(0, Cq, n_points)

    SOC_p = y0 + measured_Q / Cp
    SOC_n = x0 - measured_Q / Cn
    Up = OCP_p_40(SOC_p)
    Un = OCP_n_40(SOC_n)
    Voc_fit = Up - Un
    all_cell_ocv_construct[idx,:,0]=Voc_fit
    all_cell_ocv_construct[idx,:,1]=measured_Q

#%%

# Arrays for the residual analysis. Capacities are converted to Ah; voltage channels are
# multiplied by 4.2 (presumably stored normalised by 4.2 V; confirm).
all_Cq = data['all_Cq']*norminal_c
all_OCV_fit = data['all_OCV_fit']
all_cell_ocv = data['all_cell_ocv']
all_cell_vmea = data['all_cell_vmea']
all_cells = data['all_cells']
all_fit_results = data['all_fit_results']
all_cell_Vreal = all_cell_ocv[:,:,0]*4.2
all_cell_Qreal = all_cell_ocv[:,:,1]
all_cell_Vm = all_cell_vmea[:,:,0]*4.2
all_cell_Qm = all_cell_vmea[:,:,1]
all_cell_Vconstruct = all_cell_ocv_construct[:,:,0]

# Light smoothing along each curve before differencing.
all_cell_Vreal = savgol_filter(all_cell_Vreal,window_length=5,polyorder=1)
all_cell_Vconstruct = savgol_filter(all_cell_Vconstruct,window_length=5,polyorder=1)
all_cell_Vm = savgol_filter(all_cell_Vm,window_length=5,polyorder=1)
all_OCV_fit = savgol_filter(all_OCV_fit,window_length=5,polyorder=1)

# Rows 94: are the C/5 curves, rows :94 the C/40 curves.
all_v_diff_C5 = all_cell_Vm[94:]-all_OCV_fit[94:]
all_v_diff_C40= all_cell_Vreal[0:94]-all_OCV_fit[0:94]
all_v_diff_C40_construct = all_cell_Vreal[0:94]-all_cell_Vconstruct
all_q_diff = all_cell_Qreal*norminal_c-all_cell_Qm*norminal_c


linestyles = ['--','-.','--',':',':','--','-.']

# Mean absolute fitting error per cell vs capacity (cells below 70 % capacity skipped).
fig, axs = plt.subplots(1, 1, figsize=(8/ 2.54, 6 / 2.54), dpi=600)
# Attach legend labels to the first *plotted* cell. Keying on i == 0 left the legend
# empty whenever cell 0 was filtered out.
needs_label = True
for i in range(94):
    if all_Cq[i]/norminal_c<0.7:
        continue
    axs.plot(all_Cq[i], np.mean(abs(all_v_diff_C40[i]))*1000, 'o', color=colors[0],alpha=0.8,label='C/40' if needs_label else None)
    axs.plot(all_Cq[i+94], np.mean(abs(all_v_diff_C5[i]))*1000, '^',color=colors[1], alpha=0.8,label='C/5' if needs_label else None)
    needs_label = False
    # axs.plot(all_cell_ocv_construct[0:94,-1,1]*norminal_c, np.mean(abs(all_v_diff_C40_construct),axis=1)*1000, 's',color=colors[2], alpha=0.8,label='C/5')

axs.set_xlabel("Q [Ah]")
axs.set_ylabel("Mean fitted error [mV]")
plt.ylim([-1,35])
# plt.xlim([3,5])
axs.legend(loc='best',
          handletextpad=0.1,
          labelspacing=0.05,
          frameon=False)
plt.show()


# Fitting error along each curve vs depth of discharge (same cell filter).
fig, axs = plt.subplots(1, 1, figsize=(8/ 2.54, 6 / 2.54), dpi=600)
needs_label = True
for i in range(94):
    if all_Cq[i]/norminal_c<0.7:
        continue
    axs.plot(all_cell_Qreal[i]/max(all_cell_Qreal[i]), all_v_diff_C40[i]*1000, '-', color=colors[0],alpha=0.8,label='C/40' if needs_label else None)
    axs.plot(all_cell_Qm[i+94]/max(all_cell_Qm[i+94]), all_v_diff_C5[i]*1000, '--',color=colors[1], alpha=0.8,label='C/5' if needs_label else None)
    needs_label = False
    # axs.plot(all_cell_ocv_construct[0:94,-1,1]*norminal_c, np.mean(abs(all_v_diff_C40_construct),axis=1)*1000, 's',color=colors[2], alpha=0.8,label='C/5')

axs.set_xlabel("Depth of discharge")
axs.set_ylabel("Fitted error [mV]")
plt.ylim([-120,120])
# plt.xlim([3,5])
axs.legend(loc='best',
          handletextpad=0.1,
          labelspacing=0.05,
          frameon=False)
plt.show()

#%%

# Residual-compensation model: predict the per-point gap (measured - fitted OCV) from the
# fitted curve plus the electrode parameters, using a MultiTaskLasso.
# Alternative C/40 setup (kept for reference):
# real_OCV = all_cell_Vreal[:94]
# fit_OCV = all_OCV_fit[:94]
# Measure_Q = all_cell_Qreal[:94]
# Measure_V = all_cell_Vreal[:94]


# C/5 setup: rows 94: of the smoothed arrays (the C/5 half of the data).
real_OCV = all_cell_Vm[94:,:]
fit_OCV = np.array(all_OCV_fit[94:,:])
Measure_Q = all_cell_Qm[94:,:]
Measure_V = all_cell_Vm[94:,:]


residual = real_OCV - fit_OCV  # shape: (94, 1000)
# Keep cells whose capacity all_Cq[i] (rows < 94, the C/40 half; already in Ah) is >= 70 % of nominal.
valid_indices = [i for i in range(94) if all_Cq[i] >= 0.7 * norminal_c]

model_name = 'final_model_C5_train1.pkl'

print('Train for:',model_name)

X_features = []
Y_targets = []
for i in valid_indices:
    voc_fit = fit_OCV[i,:]
    voc_real = real_OCV[i,:]
    q_meas = Measure_Q[i,:]
    v_meas = Measure_V[i,:]
    
    # NOTE(review): the curves above come from row 94 + i (C/5 half), but the parameters
    # below are read at row i (the C/40 fit of, presumably, the same cell/RPT). Confirm
    # this pairing is intended rather than all_Cp_opt[94 + i] etc.
    cp = all_Cp_opt[i]
    cn = all_Cn_opt[i]
    cq = all_Cq[i]  # not used in the features below
    cli = all_Cli[i]  # Cp*y0 + Cn*x0, in the same normalised units as cp and cn
    # for j in range(len(voc_fit)):
    #     X_i = np.column_stack([
    #         voc_fit[j],
    #         q_meas[j],
    #         v_meas[j],
    #         cp,
    #         cn,
    #         cli,
    #     ])
    #     X_features.append(X_i.flatten())  # 
    #     Y_targets.append((voc_real[j] - voc_fit[j]))  #

    # One feature row per curve point: [voc_fit, cp, cn, cli]; the parameters are
    # broadcast to every point.
    X_i = np.column_stack([
        voc_fit,
        # q_meas,
        # v_meas,
        np.full_like(voc_fit, cp),
        np.full_like(voc_fit, cn),
        np.full_like(voc_fit, cli),
    ])
    
    #### flatten: row-major, so the 4 features are interleaved point by point
    X_features.append(X_i.flatten())  # shape: (4 * n_points,), e.g. (4000,) for 1000-point curves
    Y_targets.append((voc_real - voc_fit))  # shape: (n_points,)
    
    # plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
    # plt.ion()
    # plt.rcParams['xtick.direction'] = 'in'
    # plt.rcParams['ytick.direction'] = 'in'
    # plt.tick_params(top='on', right='on', which='both')
    # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    # plt.plot(q_meas,voc_fit)
    # plt.plot(q_meas,voc_real)
    # plt.show()
    
    

X_all = np.stack(X_features)  # shape: (n_samples, 4 * n_points)
Y_all = np.stack(Y_targets)   # shape: (n_samples, n_points)
print(X_all.shape)

# 6:2:2
# X_temp, X_test, Y_temp, Y_test = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)
# X_train, X_val, Y_train, Y_val = train_test_split(X_temp, Y_temp, test_size=0.25, random_state=42)


# use for train all
# NOTE(review): there is no held-out test set here. The "test" split is the validation split,
# which is also used for alpha selection and in the final train+val refit, so the printed
# "Test RMSE" is an in-sample error.
X_train, X_val, Y_train, Y_val = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)
X_test, Y_test = X_val, Y_val


print(f"Train size: {X_train.shape[0]}, Val size: {X_val.shape[0]}, Test size: {X_test.shape[0]}")


# Scalers are fitted on the training split and reused for every later transform.
scaler_X = StandardScaler().fit(X_train)
scaler_Y = StandardScaler().fit(Y_train)

X_train_std = scaler_X.transform(X_train)
X_val_std = scaler_X.transform(X_val)
X_test_std = scaler_X.transform(X_test)

Y_train_std = scaler_Y.transform(Y_train)
Y_val_std = scaler_Y.transform(Y_val)

# Grid search over alpha, scored by validation RMSE in original units (V).
alphas = np.logspace(-3, 1, 30)
best_alpha = None
best_val_rmse = float('inf')

for alpha in alphas:
    model = MultiTaskLasso(alpha=alpha)
    model.fit(X_train_std, Y_train_std)
    Y_val_pred_std = model.predict(X_val_std)
    Y_val_pred = scaler_Y.inverse_transform(Y_val_pred_std)
    val_rmse = np.sqrt(mean_squared_error(Y_val, Y_val_pred))

    if val_rmse < best_val_rmse:
        best_val_rmse = val_rmse
        best_alpha = alpha

print(f"Best alpha: {best_alpha:.4f}, Validation RMSE: {best_val_rmse:.4f}")


# Final model refit on train + validation with the chosen alpha (same train-fitted scalers).
X_trainval_std = scaler_X.transform(np.vstack([X_train, X_val]))
Y_trainval_std = scaler_Y.transform(np.vstack([Y_train, Y_val]))

final_model = MultiTaskLasso(alpha=best_alpha)
final_model.fit(X_trainval_std, Y_trainval_std)
### Save, then reload so the evaluation below uses the persisted model.
joblib.dump(final_model, 'saved_fittings/'+model_name)
final_model = joblib.load('saved_fittings/'+model_name)
Y_test_pred_std = final_model.predict(scaler_X.transform(X_test))
Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)

test_rmse = np.sqrt(mean_squared_error(Y_test, Y_test_pred))
print(f"Test RMSE: {test_rmse:.4f}")

# ---------------------------------------------------------------------------
# Parity plot for the model fitted above. It shows real vs. predicted targets
# (multiplied by 1000 and labelled mV), coloured by absolute error, with an
# inset histogram of the signed error.
# NOTE(review): assumes Y_test holds OCV residuals in volts; the feature/target
# construction for this section lies outside this view, so confirm.
# ---------------------------------------------------------------------------
# Every 4th column starting at column 0. NOTE(review): this is the voc_fit
# channel only if X_features were built as 4 interleaved columns per point;
# confirm. Not used by the plot below.
X_voc_fit_test = X_test[:, 0::4] 
# Custom two-colour map; unused here (the scatter below uses 'coolwarm').
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
# Per-point absolute error (x1000), min-max normalised to [0, 1] for colouring.
orig_color_values = np.abs(Y_test[:,:].reshape(-1)*1000-Y_test_pred[:,:].reshape(-1)*1000)
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
color_values = norm(orig_color_values)
# figsize is given in cm and converted to inches.
plt.figure(figsize=(5.5 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# The first call is immediately overridden by the second, which turns the
# tick marks off on all four sides.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
scatter1 = plt.scatter(Y_test[:,:].reshape(-1)*1000, Y_test_pred[:,:].reshape(-1)*1000, 
                      c=color_values, alpha=0.9, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
# y = x reference line.
plt.plot(Y_test[:,:].reshape(-1)*1000,Y_test[:,:].reshape(-1)*1000,'--',color='grey',linewidth=1)
plt.xlabel('Real values [mV]')
plt.ylabel('Predictions [mV]')
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# cbar.set_label('Normalized Color values')
# The scatter colours are already-normalised values, so the colorbar lives in
# normalised units. Place 3 ticks at the normalised positions of the
# min / mid / max error and label them with the un-normalised error values.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(norm(ticks))
cbar.set_ticklabels(tick_labels)

# Inset histogram of the signed error (real - predicted), x1000.
ax_hist = inset_axes(
    plt.gca(),
    width="40%", height="30%",
    bbox_to_anchor=(0.44, 0.19, 0.62, 0.6),  # x0, y0, width, height in figure fractions
    bbox_transform=plt.gcf().transFigure,
    loc='lower left'
)
ax_hist.hist(
    (Y_test[:, :].reshape(-1) * 1000 - Y_test_pred[:, :].reshape(-1) * 1000),
    bins=20, color='gray', edgecolor='black',linewidth=0.5
)

ax_hist.set_xlabel('Error [mV]', fontsize=6, labelpad=1)  # small labelpad keeps the inset compact
ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
ax_hist.tick_params(axis='both', which='major', labelsize=6)
plt.show()

#%% accuracy comparisons

def _ocv_error_metrics(reference, estimate):
    """Return (RMSE, MAE, max |error|, R^2) of *estimate* against *reference*.

    Both arrays are flattened. The three error metrics are computed on the
    values multiplied by 1000 (printed as mV). R^2 uses the unscaled values.
    """
    ref_mv = reference.reshape(-1) * 1000
    est_mv = estimate.reshape(-1) * 1000
    rmse = np.sqrt(mean_squared_error(est_mv, ref_mv))
    mae = mean_absolute_error(est_mv, ref_mv)
    max_ae = np.max(abs(ref_mv - est_mv))
    r2 = r2_score(reference.reshape(-1), estimate.reshape(-1))
    return rmse, mae, max_ae, r2


# For each saved residual model, rebuild its dataset and 6:2:2 split. Then
# compare the plain electrode-fit OCV ("Fit") with fit + predicted residual
# ("Test") against the measured reference OCV.
model_names = ['final_model_C40_test1.pkl','final_model_C5_test1.pkl']
for model_name in model_names:
    print('Test results for:', model_name)

    # Rows [:94] of all_OCV_fit pair with the C/40 curves; rows [94:] pair
    # with the C/5 curves.
    if model_name == 'final_model_C40_test1.pkl':
        real_OCV, fit_OCV, Measure_Q, Measure_V = (
            all_cell_Vreal[:94], all_OCV_fit[:94], all_cell_Qreal[:94], all_cell_Vreal[:94],
        )
    elif model_name == 'final_model_C5_test1.pkl':
        real_OCV, fit_OCV, Measure_Q, Measure_V = (
            all_cell_Vm[94:,:], np.array(all_OCV_fit[94:,:]), all_cell_Qm[94:,:], all_cell_Vm[94:,:],
        )
    else:
        raise ValueError("Unknown model")

    residual = real_OCV - fit_OCV
    # Keep only cells whose all_Cq is at least 70 % of norminal_c.
    valid_indices = [i for i in range(94) if all_Cq[i] >= 0.7 * norminal_c]

    X_features = []
    Y_targets = []
    for i in valid_indices:
        voc_fit, voc_real = fit_OCV[i,:], real_OCV[i,:]
        q_meas, v_meas = Measure_Q[i,:], Measure_V[i,:]
        cp, cn, cq, cli = all_Cp_opt[i], all_Cn_opt[i], all_Cq[i], all_Cli[i]

        # Columns per point: [voc_fit, cp, cn, cli]. Flattened row-major, so
        # the row has length 4 * len(voc_fit) and voc_fit sits at 0, 4, 8, ...
        X_i = np.column_stack([voc_fit] + [np.full_like(voc_fit, c) for c in (cp, cn, cli)])
        X_features.append(X_i.flatten())
        Y_targets.append(voc_real - voc_fit)

    X_all = np.stack(X_features)
    Y_all = np.stack(Y_targets)

    # 6:2:2 split: 20 % test, then 25 % of the remaining 80 % (= 20 %) as val.
    X_temp, X_test, Y_temp, Y_test = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)
    X_train, X_val, Y_train, Y_val = train_test_split(X_temp, Y_temp, test_size=0.25, random_state=42)

    # NOTE(review): the scalers are re-fit here because they are not saved with
    # the model. This reproduces the training-time scaling only if the saved
    # model was trained on this same split; confirm.
    scaler_X = StandardScaler().fit(X_train)
    scaler_Y = StandardScaler().fit(Y_train)
    X_train_std, X_val_std, X_test_std = (scaler_X.transform(a) for a in (X_train, X_val, X_test))
    Y_train_std, Y_val_std = (scaler_Y.transform(a) for a in (Y_train, Y_val))

    final_model = joblib.load('saved_fittings/'+model_name)
    Y_test_pred_std = final_model.predict(scaler_X.transform(X_test))
    Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)

    fit_ocv = X_test[:, 0::4]        # voc_fit channel of each test row
    orig_ocv = fit_ocv + Y_test      # voc_fit + (voc_real - voc_fit)
    pre_ocv = fit_ocv + Y_test_pred  # voc_fit + predicted residual

    test_rmse, test_mae, test_maae, test_r2 = _ocv_error_metrics(orig_ocv, pre_ocv)
    fit_rmse, fit_mae, fit_maae, fit_r2 = _ocv_error_metrics(orig_ocv, fit_ocv)

    print(f"Test RMSE: {test_rmse:.4f}", f"Test MAE: {test_mae:.4f}", f"Test MaxAE: {test_maae:.4f}", f"Test R2: {test_r2:.4f}")
    print(f"Fit RMSE: {fit_rmse:.4f}", f"Fit MAE: {fit_mae:.4f}", f"Fit MaxAE: {fit_maae:.4f}", f"Fit R2: {fit_r2:.4f}")
    

#%%% train model from C/5 to C/40 and contain residual.
# model_electrode = joblib.load( 'saved_fittings/'+'electrode_C5_to_C40.pkl')

# ---------------------------------------------------------------------------
# Voltage-residual ("V") model dataset. The inputs are the OCV reconstructed
# from the predicted electrode states plus those states. The target is the
# difference between the reference OCV and that reconstruction.
# ---------------------------------------------------------------------------
real_OCV = all_cell_Vreal[:94]
# Reconstructed OCV from the predicted electrode states (previously: all_OCV_fit[:94]).
# NOTE(review): the features below read all_cell_ocv_construct[i,:,0], not this
# array; confirm both hold the same reconstruction.
fit_OCV = all_cell_Vconstruct
real_Q = all_cell_Qreal[:94]
# Measure_V = all_cell_Vreal[:94]


# real_OCV = all_cell_Vm[94:,:]
# fit_OCV = np.array(all_OCV_fit[94:,:])
Measure_Q = all_cell_Qm[94:,:] # measured input curves are taken from the C/5 data (rows 94+)
Measure_V = all_cell_Vm[94:,:]


residual = real_OCV - fit_OCV  # not used further in this section
# Keep only cells whose all_Cq is at least 70 % of norminal_c.
valid_indices = [i for i in range(94) if all_Cq[i] >= 0.7 * norminal_c]

model_name = 'final_model_C40_test_electrode_prediction_v1.pkl'

print('Train for:',model_name)

X_features = []
Y_targets = []
for i in valid_indices:
    voc_fit = all_cell_ocv_construct[i,:,0]
    voc_real = real_OCV[i,:]
    q_real = real_Q[i,:]
    
    # Predicted electrode states: entries 0..3 of all_predictions[i][0] are
    # read as Cq, Cp, Cn, Cli and divided by nominal_capacity.
    Cq = all_predictions[i][0][0] /nominal_capacity
    Cp = all_predictions[i][0][1] /nominal_capacity
    Cn = all_predictions[i][0][2] /nominal_capacity
    Cli = all_predictions[i][0][3] /nominal_capacity
    y0 = 0  # unused below
    x0 = Cli / Cn  # unused below
    # Capacity axis implied by the predicted Cq; assumes 1000 points per curve (TODO confirm).
    predict_q = np.linspace(0, Cq, 1000)
    
    # Per-point columns: [voc_fit, Cp, Cn, Cli] (state values broadcast to every point).
    X_i = np.column_stack([
        voc_fit,
        # predict_q,
        np.full_like(voc_fit, Cp),
        np.full_like(voc_fit, Cn),
        np.full_like(voc_fit, Cli),
    ])
    
    # Target: reference OCV minus reconstructed OCV.
    Y_i = np.column_stack([
        voc_real - voc_fit
        # q_real - predict_q
    ])
    
    # Flatten row-major: X row has length 4 * len(voc_fit), with voc_fit at
    # columns 0, 4, 8, ...; the Y row has length len(voc_fit).
    X_features.append(X_i.flatten())
    # Y_targets.append((voc_real - voc_fit))
    Y_targets.append(Y_i.flatten())
    # Sanity plot for the first retained cell: reference vs. predicted capacity axis.
    if i == valid_indices[0]:
        plt.figure(figsize=(5.5 / 2.54,6 / 2.54), dpi=600)
        plt.ion()
        plt.rcParams['xtick.direction'] = 'in'
        plt.rcParams['ytick.direction'] = 'in'
        plt.tick_params(top='on', right='on', which='both')
        plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        plt.plot(q_real*norminal_c,'-',color=colors[0],linewidth=1.5)
        plt.plot(predict_q*norminal_c,'--',color=colors[1],linewidth=1.5)
        plt.xlabel('Sample points')
        plt.ylabel('Capacity [Ah]')
        # plt.legend()
        plt.show()
    #%
    
# Stack the voltage-model samples: one row per retained cell.
X_all_v = np.stack(X_features)
Y_all_v = np.stack(Y_targets)
print(X_all_v.shape, Y_all_v.shape)

# 6:2:2 split. Row positions are split alongside the data so a test row can
# be traced back to its position in valid_indices.
num_samples = len(X_all_v)
indices_v = np.arange(num_samples)
(X_temp_v, X_test_v,
 Y_temp_v, Y_test_v,
 idx_temp_v, idx_test_v) = train_test_split(X_all_v, Y_all_v, indices_v, test_size=0.2, random_state=42)
(X_train_v, X_val_v,
 Y_train_v, Y_val_v,
 idx_train_v, idx_val_v) = train_test_split(X_temp_v, Y_temp_v, idx_temp_v, test_size=0.25, random_state=42)

# Alternative "train on all" mode (test set = validation set):
# X_train_v, X_val_v, Y_train_v, Y_val_v = train_test_split(X_all_v, Y_all_v, test_size=0.2, random_state=42)
# X_test_v, Y_test_v = X_val_v, Y_val_v

print(f"Train size: {X_train_v.shape[0]}, Val size: {X_val_v.shape[0]}, Test size: {X_test_v.shape[0]}")

# Standardise with statistics from the voltage-model training rows only.
scaler_X_v = StandardScaler().fit(X_train_v)
scaler_Y_v = StandardScaler().fit(Y_train_v)

X_train_std_v, X_val_std_v, X_test_std_v = [
    scaler_X_v.transform(split) for split in (X_train_v, X_val_v, X_test_v)
]
Y_train_std_v, Y_val_std_v = [
    scaler_Y_v.transform(split) for split in (Y_train_v, Y_val_v)
]

# alphas = np.logspace(-3, 1, 30)
# best_alpha = None
# best_val_rmse = float('inf')

# for alpha in alphas:
#     model = MultiTaskLasso(alpha=alpha)
#     model.fit(X_train_std_v, Y_train_std_v)
#     Y_val_pred_std_v = model.predict(X_val_std_v)
#     Y_val_pred_v = scaler_Y_v.inverse_transform(Y_val_pred_std_v)
#     val_rmse_v = np.sqrt(mean_squared_error(Y_val_v, Y_val_pred_v))

#     if val_rmse_v < best_val_rmse:
#         best_val_rmse = val_rmse_v
#         best_alpha = alpha

# print(f"Best alpha: {best_alpha:.4f}, Validation RMSE: {best_val_rmse:.4f}")

# X_trainval_std_v = scaler_X_v.transform(np.vstack([X_train_v, X_val_v]))
# Y_trainval_std_v = scaler_Y_v.transform(np.vstack([Y_train_v, Y_val_v]))

# final_model_v = MultiTaskLasso(alpha=best_alpha)
# final_model_v.fit(X_trainval_std_v, Y_trainval_std_v)
# joblib.dump(final_model_v, 'saved_fittings/'+model_name)
# Load the saved voltage-residual model and score it on the held-out test rows.
final_model_v = joblib.load(f'saved_fittings/{model_name}')

Y_test_pred_std_v = final_model_v.predict(scaler_X_v.transform(X_test_v))
Y_test_pred_v = scaler_Y_v.inverse_transform(Y_test_pred_std_v)
# Short aliases used by the plots and metrics below.
Y_test_V, Y_real_V = Y_test_pred_v, Y_test_v

test_rmse = np.sqrt(mean_squared_error(Y_real_V, Y_test_V))
print(f"Test V RMSE: {test_rmse:.4f}")

# Feature rows interleave [voc_fit, Cp, Cn, Cli]; every 4th column from 0 is voc_fit.
X_voc_fit_test_V = X_test_v[:, 0::4]
# X_voc_fit_test_Q = X_test[:, 1::5]*norminal_c 

# ---------------------------------------------------------------------------
# Parity plot for the voltage-residual model on its test rows. It shows the
# real vs. predicted residual (x1000, labelled mV), coloured by absolute
# error, with an inset histogram of the signed error.
# ---------------------------------------------------------------------------
# Custom two-colour map; unused here (the scatter below uses 'coolwarm').
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
# Per-point absolute error (x1000), min-max normalised to [0, 1] for colouring.
orig_color_values = np.abs(Y_real_V[:,:].reshape(-1)*1000-Y_test_V[:,:].reshape(-1)*1000)
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
color_values = norm(orig_color_values)
# figsize is given in cm and converted to inches.
plt.figure(figsize=(5.5 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# The first call is immediately overridden by the second, which turns the
# tick marks off on all four sides.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
scatter1 = plt.scatter(Y_real_V[:,:].reshape(-1)*1000, Y_test_V[:,:].reshape(-1)*1000, 
                      c=color_values, alpha=0.9, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
# y = x reference line.
plt.plot(Y_real_V[:,:].reshape(-1)*1000,Y_real_V[:,:].reshape(-1)*1000,'--',color='grey',linewidth=1)
plt.xlabel('Real values [mV]')
plt.ylabel('Predictions [mV]')
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# cbar.set_label('Normalized Color values')
# The colorbar is in normalised units. Put 3 ticks at the normalised positions
# of the min / mid / max error and label them with the un-normalised values.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(norm(ticks))
cbar.set_ticklabels(tick_labels)

# Inset histogram of the signed error (real - predicted), x1000.
ax_hist = inset_axes(
    plt.gca(),
    width="40%", height="30%",
    bbox_to_anchor=(0.44, 0.19, 0.62, 0.6),  # x0, y0, width, height in figure fractions
    bbox_transform=plt.gcf().transFigure,
    loc='lower left'
)
ax_hist.hist(
    (Y_real_V[:, :].reshape(-1) * 1000 - Y_test_V[:, :].reshape(-1) * 1000),
    bins=20, color='gray', edgecolor='black',linewidth=0.5
)

ax_hist.set_xlabel('Error [mV]', fontsize=6, labelpad=1)  # reduce labelpad to keep the inset compact
ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
ax_hist.tick_params(axis='both', which='major', labelsize=6)
plt.show()


#  Q model 
# Capacity-residual ("Q") model dataset. The inputs are the capacity axis
# implied by the predicted Cq plus the predicted electrode states. The target
# is the reference capacity minus that axis.
model_name = 'final_model_C40_test_electrode_prediction_q1.pkl'

print('Train for:', model_name)

X_features, Y_targets = [], []
for i in valid_indices:
    voc_fit = all_cell_ocv_construct[i,:,0]  # only supplies shape/dtype for the constant columns
    voc_real = real_OCV[i,:]
    q_real = real_Q[i,:]

    # Entries 0..3 of all_predictions[i][0], divided by nominal_capacity.
    Cq, Cp, Cn, Cli = (all_predictions[i][0][k] / nominal_capacity for k in range(4))
    y0 = 0
    x0 = Cli / Cn
    predict_q = np.linspace(0, Cq, 1000)

    # Per-point columns: [predict_q, Cp, Cn, Cli]; flattened row-major, the
    # row has length 4 * len(predict_q) with predict_q at columns 0, 4, 8, ...
    state_columns = [np.full_like(voc_fit, value) for value in (Cp, Cn, Cli)]
    X_i = np.column_stack([predict_q, *state_columns])
    Y_i = np.column_stack([q_real - predict_q])

    X_features.append(X_i.flatten())
    Y_targets.append(Y_i.flatten())

X_all_q = np.stack(X_features)
Y_all_q = np.stack(Y_targets)
print(X_all_q.shape, Y_all_q.shape)

# 6:2:2 split of the capacity-model samples, carrying row positions along so
# that test rows can be mapped back to entries of valid_indices.
num_samples = len(X_all_q)
indices_q = np.arange(num_samples)

(X_temp_q, X_test_q,
 Y_temp_q, Y_test_q,
 idx_temp_q, idx_test_q) = train_test_split(X_all_q, Y_all_q, indices_q, test_size=0.2, random_state=42)
(X_train_q, X_val_q,
 Y_train_q, Y_val_q,
 idx_train_q, idx_val_q) = train_test_split(X_temp_q, Y_temp_q, idx_temp_q, test_size=0.25, random_state=42)

# Alternative "train on all" mode (test set = validation set):
# X_train_q, X_val_q, Y_train_q, Y_val_q = train_test_split(X_all_q, Y_all_q, test_size=0.2, random_state=42)
# X_test_q, Y_test_q = X_val_q, Y_val_q

print(f"Train size: {X_train_q.shape[0]}, Val size: {X_val_q.shape[0]}, Test size: {X_test_q.shape[0]}")

# Standardise with statistics from the training rows only.
scaler_X_q = StandardScaler().fit(X_train_q)
scaler_Y_q = StandardScaler().fit(Y_train_q)

X_train_std_q, X_val_std_q, X_test_std_q = map(scaler_X_q.transform, (X_train_q, X_val_q, X_test_q))
Y_train_std_q, Y_val_std_q = map(scaler_Y_q.transform, (Y_train_q, Y_val_q))


# alphas = np.logspace(-3, 1, 30)
# best_alpha = None
# best_val_rmse = float('inf')

# for alpha in alphas:
#     model = MultiTaskLasso(alpha=alpha)
#     model.fit(X_train_std_q, Y_train_std_q)
#     Y_val_pred_std_q = model.predict(X_val_std_q)
#     Y_val_pred_q = scaler_Y_q.inverse_transform(Y_val_pred_std_q)
#     val_rmse_q = np.sqrt(mean_squared_error(Y_val_q, Y_val_pred_q))

#     if val_rmse_q < best_val_rmse:
#         best_val_rmse = val_rmse_q
#         best_alpha = alpha

# print(f"Best alpha: {best_alpha:.4f}, Validation RMSE: {best_val_rmse:.4f}")


# X_trainval_std_q = scaler_X_q.transform(np.vstack([X_train_q, X_val_q]))
# Y_trainval_std_q = scaler_Y_q.transform(np.vstack([Y_train_q, Y_val_q]))

# final_model_q = MultiTaskLasso(alpha=best_alpha)
# final_model_q.fit(X_trainval_std_q, Y_trainval_std_q)

# joblib.dump(final_model_q, 'saved_fittings/'+model_name)
# Load the saved capacity-residual model and score it on the held-out test rows.
final_model_q = joblib.load(f'saved_fittings/{model_name}')

Y_test_pred_std_q = final_model_q.predict(scaler_X_q.transform(X_test_q))
Y_test_pred_q = scaler_Y_q.inverse_transform(Y_test_pred_std_q)
# Short aliases used by the plots and metrics below.
Y_test_Q, Y_real_Q = Y_test_pred_q, Y_test_q

test_rmse_q = np.sqrt(mean_squared_error(Y_real_Q, Y_test_Q))
print(f"Test Q RMSE: {test_rmse_q:.4f}")

# Feature rows interleave [predict_q, Cp, Cn, Cli]; every 4th column from 0 is
# the predicted capacity axis (kept under its historical "voc_fit" name).
X_voc_fit_test_Q = X_test_q[:, 0::4]

# ---------------------------------------------------------------------------
# Parity plot for the capacity-residual model on its test rows. Values are
# multiplied by norminal_c * 1000 and labelled mAh.
# NOTE(review): the mAh label assumes the residual is capacity normalised by
# norminal_c (in Ah); confirm.
# ---------------------------------------------------------------------------
# Custom two-colour map; unused here (the scatter below uses 'coolwarm').
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
# Per-point absolute error, min-max normalised to [0, 1] for colouring.
orig_color_values = np.abs(Y_real_Q[:,:].reshape(-1)*norminal_c*1000-Y_test_Q[:,:].reshape(-1)*norminal_c*1000)
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
color_values = norm(orig_color_values)
# figsize is given in cm and converted to inches.
plt.figure(figsize=(5.5 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# The first call is immediately overridden by the second, which turns the
# tick marks off on all four sides.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
scatter1 = plt.scatter(Y_real_Q[:,:].reshape(-1)*norminal_c*1000, Y_test_Q[:,:].reshape(-1)*norminal_c*1000, 
                      c=color_values, alpha=0.9, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
# y = x reference line.
plt.plot(Y_real_Q[:,:].reshape(-1)*norminal_c*1000,Y_real_Q[:,:].reshape(-1)*norminal_c*1000,'--',color='grey',linewidth=1)
plt.xlabel('Real values [mAh]')
plt.ylabel('Predictions [mAh]')
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# cbar.set_label('Normalized Color values')
# The colorbar is in normalised units. Put 3 ticks at the normalised positions
# of the min / mid / max error and label them with the un-normalised values.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(norm(ticks))
cbar.set_ticklabels(tick_labels)

# Inset histogram of the signed error (real - predicted).
ax_hist = inset_axes(
    plt.gca(),
    width="40%", height="30%",
    bbox_to_anchor=(0.44, 0.19, 0.62, 0.6),  # x0, y0, width, height in figure fractions
    bbox_transform=plt.gcf().transFigure,
    loc='lower left'
)
ax_hist.hist(
    (Y_real_Q[:, :].reshape(-1) * norminal_c*1000 - Y_test_Q[:, :].reshape(-1) * norminal_c*1000),
    bins=20, color='gray', edgecolor='black',linewidth=0.5
)

ax_hist.set_xlabel('Error [mAh]', fontsize=6, labelpad=1)  # small labelpad keeps the inset compact
ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
ax_hist.tick_params(axis='both', which='major', labelsize=6)
plt.show()



def _residual_error_metrics(y_true, y_pred, scale):
    """Return (RMSE, max |error|, MAE, R^2) for flattened residual arrays.

    Both arrays are first multiplied by *scale*. R^2 is computed on those
    scaled values; the three error metrics on the scaled values times 1000.
    A scale of 1 leaves the values unchanged (multiplication by 1 is exact).
    """
    true_scaled = y_true.reshape(-1) * scale
    pred_scaled = y_pred.reshape(-1) * scale
    true_milli = true_scaled * 1000
    pred_milli = pred_scaled * 1000
    rmse = np.sqrt(mean_squared_error(true_milli, pred_milli))
    max_ae = np.max(abs(true_milli - pred_milli))
    mae = mean_absolute_error(true_milli, pred_milli)
    r2 = r2_score(true_scaled, pred_scaled)
    return rmse, max_ae, mae, r2


# Capacity model: residuals scaled by norminal_c (reported in mAh).
test_rmse, test_maae, test_mae, test_r2 = _residual_error_metrics(Y_real_Q, Y_test_Q, norminal_c)
print(f"Test RMSE Q: {test_rmse:.4f}", f"Test MAE: {test_mae:.4f}", f"Test MaxAE: {test_maae:.4f}", f"Test R2: {test_r2:.4f}")

# Voltage model: residuals unscaled (reported in mV).
test_rmse, test_maae, test_mae, test_r2 = _residual_error_metrics(Y_real_V, Y_test_V, 1)
print(f"Test RMSE V: {test_rmse:.4f}", f"Test MAE: {test_mae:.4f}", f"Test MaxAE: {test_maae:.4f}", f"Test R2: {test_r2:.4f}")


#%%
# Density of capacity errors [mAh] on the capacity-model test cells:
#   "Measured C/5"   : Measure_Q - real_Q
#   "Predicted C/40" : predicted capacity axis - real_Q   (= -Y_real_Q)
#   "Compensated"    : predicted axis + predicted residual - real_Q
# The inset is a histogram of the remaining residual-model error.
fig2, ax2 = plt.subplots(figsize=(6/2.54, 6/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# The first call is overridden by the second, which turns all tick marks off.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

# idx_test_q are row positions in X_all_q, i.e. positions within valid_indices,
# not raw cell indices. Map them back to cell rows before indexing real_Q /
# Measure_Q, as the dV/dQ plot does via Measure_Q[valid_indices][i]. Indexing
# directly would select the wrong cells whenever the capacity filter dropped
# any cell, so the baseline would not match the Y_real_Q rows.
test_cells_q = np.asarray(valid_indices)[idx_test_q]

sns.kdeplot(
    (-real_Q[test_cells_q].reshape(-1) * norminal_c*1000 + Measure_Q[test_cells_q].reshape(-1) * norminal_c*1000),
    ax=ax2, label='Measured C/5', color=colors[5], linewidth=1.5, fill=True
)

sns.kdeplot(
    (-Y_real_Q.reshape(-1) * norminal_c*1000),
    ax=ax2, label='Predicted C/40', color=colors[0], linewidth=1.5, fill=True
)

sns.kdeplot(
    (-Y_real_Q.reshape(-1) * norminal_c*1000 + Y_test_Q.reshape(-1) * norminal_c*1000),
    ax=ax2, label='Compensated', color=colors[1], linewidth=1.5, fill=True
)

# Inset: histogram of (real residual - predicted residual).
ax_hist = inset_axes(
    plt.gca(),
    width="40%", height="30%",
    bbox_to_anchor=(0.6, 0.4, 0.6, 1),
    bbox_transform=plt.gcf().transFigure,
    loc='lower left'
)
ax_hist.hist(
    (Y_real_Q[:, :].reshape(-1) * norminal_c*1000 - Y_test_Q[:, :].reshape(-1) * norminal_c*1000),
    bins=20, color=colors[1], edgecolor='black',linewidth=0.5
)

ax_hist.set_xlabel('Error [mAh]', fontsize=6, labelpad=1)
ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
ax_hist.tick_params(axis='both', which='major', labelsize=6)

ax2.set_xlabel('Error [mAh]')
ax2.set_ylabel('Density')
ax2.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
# ax2.legend(loc='upper right', frameon=False)
plt.show()


# Density of voltage errors [mV] on the voltage-model test cells:
#   "Measured C/5"   : Measure_V - real_OCV
#   "Predicted C/40" : reconstructed OCV - real_OCV   (= -Y_real_V)
#   "Compensated"    : reconstructed OCV + predicted residual - real_OCV
# The inset is a histogram of the remaining residual-model error.
fig2, ax2 = plt.subplots(figsize=(6/2.54, 6/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# The first call is overridden by the second, which turns all tick marks off.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

# Y_real_V rows come from the voltage split, so use its test positions
# (idx_test_v). Those positions index valid_indices, not raw cell rows, so map
# them back before indexing real_OCV / Measure_V. Indexing directly would pick
# the wrong cells whenever the capacity filter dropped any cell.
test_cells_v = np.asarray(valid_indices)[idx_test_v]

sns.kdeplot(
    (-real_OCV[test_cells_v].reshape(-1) *1000 + Measure_V[test_cells_v].reshape(-1) *1000),
    ax=ax2, label='Measured C/5', color=colors[5], linewidth=1.5, fill=True
)

sns.kdeplot(
    (-Y_real_V.reshape(-1) *1000),
    ax=ax2, label='Predicted C/40', color=colors[0], linewidth=1.5, fill=True
)

sns.kdeplot(
    (-Y_real_V.reshape(-1) *1000 + Y_test_V.reshape(-1) *1000),
    ax=ax2, label='Compensated', color=colors[1], linewidth=1.5, fill=True
)
# Inset: histogram of (real residual - predicted residual).
ax_hist = inset_axes(
    plt.gca(),
    width="40%", height="30%",
    bbox_to_anchor=(0.25, 0.4, 0.6, 1),
    bbox_transform=plt.gcf().transFigure,
    loc='lower left'
)
ax_hist.hist(
    (Y_real_V[:, :].reshape(-1) * 1000 - Y_test_V[:, :].reshape(-1) * 1000),
    bins=20, color=colors[1], edgecolor='black',linewidth=0.5
)

ax_hist.set_xlabel('Error [mV]', fontsize=6, labelpad=1)
ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
ax_hist.tick_params(axis='both', which='major', labelsize=6)

ax2.set_xlabel('Error [mV]')
ax2.set_ylabel('Density')
ax2.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
# ax2.legend(loc='upper right', frameon=False)
plt.show()


#%%

# ---------------------------------------------------------------------------
# dV/dQ comparison for one test cell. It overlays the C/5 measurement, the
# C/40 reference, the reconstruction from predicted electrode states, and the
# reconstruction after residual compensation.
# ---------------------------------------------------------------------------
# Compensated curves = reconstruction + predicted residual (one row per test sample).
Compen_ocv = Y_test_V+X_voc_fit_test_V
Compen_Q = Y_test_Q+X_voc_fit_test_Q
Compen_Q = Compen_Q*norminal_c
# idx_test_run = idx_test[11:12]
# figsize is given in cm and converted to inches.
plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# The first call is overridden by the second, which turns all tick marks off.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

# i_idx selects a test row (other examples tried: 10, 15, 4, 18); the test set
# must have more than i_idx rows. i is that row's position within valid_indices.
# NOTE: the V and Q test rows are used interchangeably below. Both splits use
# the same sample count and random_state=42, so idx_test_v and idx_test_q coincide.
i_idx =15  #10  15 4 18
i = idx_test_v[i_idx]
# for i_idx, i in enumerate(idx_test_run):
    # print(i_idx)
# Map through valid_indices first, then take position i.
measured_q = Measure_Q[valid_indices][i]*norminal_c
# measured_q = np.linspace(0, measured_q[-1], 1000)
measured_v = Measure_V[valid_indices][i]
real_q = real_Q[valid_indices][i]*norminal_c
# real_q = np.linspace(0, real_q[-1], 1000)
real_v = real_OCV[valid_indices][i]

predict_q = X_voc_fit_test_Q[i_idx]*norminal_c
# predict_q = np.linspace(0, predict_q, 1000)
predict_v = X_voc_fit_test_V[i_idx]

compensate_v = Compen_ocv[i_idx]
compensate_q = Compen_Q[i_idx]
# OCV_recon = Reconst_ocv[idx]
# OCV_compen = Compen_ocv[idx]
# dv_dq_compensate = gradient( OCV_compen, measured_Q )
# dv_dq_recon = gradient( OCV_recon, measured_Q )
# numpy.gradient with explicit sample coordinates (non-uniform spacing allowed).
# Repeated q values may yield inf/nan slopes.
dv_dq_measured = gradient( measured_v, measured_q )
dv_dq_real = gradient( real_v, real_q )
dv_dq_predict = gradient( predict_v, predict_q )
dv_dq_compensate = gradient( compensate_v, compensate_q )

# plt.plot(measured_q, measured_v,'--', alpha=0.8, color=colors[0],label=f'C/5 Measured')
# plt.plot(real_q,real_v, '-', alpha=0.8, color=colors[1], label=f'C/40 Real' )
# plt.plot(predict_q, predict_v,'--', color=colors[6], alpha=0.8, label=f'C/40 Reconstructed' )
# plt.plot(compensate_q, compensate_v,'--', color=colors[4], alpha=0.8, label=f'C/40 Compensation')

# Plot -dV/dQ so the curves are positive; the y-range is clipped to [0, 1].
plt.plot(measured_q, -dv_dq_measured,'--', alpha=0.8, color=colors[5],label=f'C/5 Measured')
plt.plot(real_q,-dv_dq_real, '-', alpha=0.8, color=colors[4], label=f'C/40 Real' )
plt.plot(predict_q, -dv_dq_predict,'--', color=colors[0], alpha=0.8, label=f'C/40 Reconstructed' )
plt.plot(compensate_q, -dv_dq_compensate,'--', color=colors[1], alpha=0.8, label=f'C/40 Compensation')
plt.ylim([0,1])

plt.xlabel('Q [Ah]')
plt.ylabel('dV/dQ [V/Ah]')
# plt.grid(True)

plt.legend(loc='best',
handletextpad=0.1, 
labelspacing=0.05,
frameon=False,
fontsize=10)
plt.show()


